import torch
import pandas as pd
from molfeat.trans.pretrained.hf_transformers import PretrainedHFTransformer  # type: ignore
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset
from selfies import encoder as selfies_encoder
from tqdm import tqdm
import numpy as np
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
import argparse
# Ensure these local imports are correct in your project structure
# from .predictor.tuning import FineTuningModel,FineTuningModelv2
from transformers.adapters import AdapterConfig, LoRAConfig
from transformers import AdapterType, PrefixTuningConfig, PromptEncoderConfig, PPOTrainer
from torch.nn import Parameter


# Mocking the local import for standalone execution
class FineTuningModel(nn.Module):
    """Single linear regression head mapping pooled embeddings to targets.

    Stand-in for the project's ``predictor.tuning.FineTuningModel`` so the
    script can run on its own.
    """

    def __init__(self, input_dim, output_dim):
        super(FineTuningModel, self).__init__()
        # Attribute name ``fc`` is part of the saved state_dict key space.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the affine projection to ``x`` (last dim must be ``input_dim``)."""
        projected = self.fc(x)
        return projected


class FineTuningModelv2(nn.Module):
    """Two-layer MLP regression head (Linear -> ReLU -> Linear).

    Stand-in for the project's ``predictor.tuning.FineTuningModelv2``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FineTuningModelv2, self).__init__()
        # Layers are created in the same order (fc1 before fc2) so seeded
        # initialisation and state_dict keys are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` to the hidden width, apply ReLU, then project to the output."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)


class Trainer:
    """Fine-tune a pretrained chemical language model with a regression head.

    Constructing a ``Trainer`` runs the whole setup pipeline: CSV loading,
    tokenizer/model loading, head construction, DataLoader creation and device
    placement. Call :meth:`train` afterwards to run fine-tuning. The best head
    (by validation loss) is saved to ``f'{name}.pth'``. When a PEFT method is
    active, its adapters are saved to ``f'{name}_lora_adapters'``.
    """

    def __init__(self, data, kind='ChemGPT-19M', notation='selfies', target='measured_log_sol', shuffle=True,
                 name='Benchmark', lora=None,
                 # == CatalystLLM Paper Parameters ==
                 lr=5.0e-5,
                 batch_size=8,  # Per-card batch size from paper
                 epochs=2,  # Paper used 1.06 epochs, we set to 2 and use max_steps
                 max_steps=85255,  # Total training steps from paper
                 weight_decay=0.01,  # Common default for AdamW, paper does not specify
                 r=8,  # LoRA rank
                 alpha=16,  # LoRA alpha (scale factor)
                 dropout=0.1,  # LoRA dropout
                 neftune_noise_alpha=5,  # NEFTune noise regularization parameter
                 compute_dtype='bf16',  # Use Brain Float 16 for mixed precision
                 # == Original/Other Parameters ==
                 gpu=False, loss='MAE', hidden_dim=256, quantization_bits=4, nn='v1', clip_grad=True,
                 adapter_type="houlsby", reduction_factor=16, encoder_only=False,
                 num_virtual_tokens=10, encoder_hidden_size=512, num_trainable_layers=2,
                 intervention_dim=64, sparsity=0.99, reward_model=None, preference_dataset=None):
        """Store hyperparameters and run the setup pipeline.

        Args:
            data: Path to a CSV with a ``split`` column ('train'/'val'/'test'),
                the target column, and a ``smiles`` and/or ``selfies`` column.
            kind: Pretrained model name forwarded to ``PretrainedHFTransformer``.
            notation: Input notation, 'smiles' or 'selfies'.
            target: Name of the regression target column.
            shuffle: Whether the training DataLoader shuffles.
            name: Prefix used for saved checkpoints.
            lora: PEFT method ('lora', 'qlora', ...). None or the string 'none'
                disables PEFT.
            lr, batch_size, weight_decay: AdamW / DataLoader settings.
            epochs: Number of epochs. Fractional values are rounded up.
            max_steps: Global optimisation-step cap. A value <= 0 disables it.
            r, alpha, dropout: LoRA rank, scaling alpha and dropout.
            neftune_noise_alpha: NEFTune noise scale. None or <= 0 disables it.
            compute_dtype: 'bf16' selects bfloat16 autocast when CUDA supports
                it. Anything else selects float32.
            gpu: Not consulted; ``load_to_gpu`` auto-detects CUDA.
            loss: 'MAE', 'MSE', or anything else for SmoothL1 (Huber).
            hidden_dim: Hidden width of the 'v2' head.
            nn: Head variant, 'v1' (linear) or 'v2' (MLP). Note that this
                parameter shadows ``torch.nn`` inside this method only.
            clip_grad: Clip gradient norms to 1.0 after every backward pass.
            The remaining arguments are stored for PEFT methods not set up here.
        """
        self.datafile = data
        self.kind = kind
        self.notation = notation
        self.target = target
        self.shuffle = shuffle
        self.name = name
        # The CLI passes the string 'none' for "no PEFT". Normalise it so the
        # truthiness / `is not None` checks below treat it as disabled.
        self.lora = None if lora in (None, 'none') else lora
        self.nn = nn  # head variant string ('v1' / 'v2')
        self.loss = loss
        self.clip_gradient = clip_grad
        self.global_step = 0

        # === Set hyperparameters based on CatalystLLM specs ===
        self.best_lr = lr
        self.best_batch_size = batch_size
        self.best_epochs = int(np.ceil(epochs))  # Ensure epochs is an integer for range()
        self.max_steps = max_steps
        self.best_weight_decay = weight_decay
        self.best_r = r
        self.best_alpha = alpha
        self.best_dropout = dropout
        self.neftune_noise_alpha = neftune_noise_alpha

        # Only query bf16 support when CUDA exists, and only warn about a
        # fallback when bf16 was actually requested.
        wants_bf16 = compute_dtype == 'bf16'
        bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        self.compute_dtype = torch.bfloat16 if wants_bf16 and bf16_supported else torch.float32
        if wants_bf16 and not bf16_supported:
            print("Warning: bfloat16 is not supported on this device. Falling back to float32.")

        # Store other PEFT method parameters
        self.quantization_bits = quantization_bits
        self.hidden_dim = hidden_dim
        self.adapter_type = adapter_type
        self.reduction_factor = reduction_factor
        self.encoder_only = encoder_only
        self.num_virtual_tokens = num_virtual_tokens
        self.encoder_hidden_size = encoder_hidden_size
        self.num_trainable_layers = num_trainable_layers
        self.sparsity = sparsity
        self.intervention_dim = intervention_dim
        self.reward_model = reward_model
        self.preference_data = preference_dataset

        print('Loading parameters:')
        # Order matters: the loaders need the tokenizer from loadLLM, and
        # load_to_gpu needs the head and criterion built by loadNN.
        steps = [
            ("Loading data", self.loaddata),
            ("Loading LLM", self.loadLLM),
            ("Loading Neural Network", self.loadNN),
            ("Setting up DataLoader", self.setupLoader),
            ("Loading to GPU", self.load_to_gpu)
        ]

        with tqdm(total=len(steps), desc="Initialization Progress") as pbar:
            for step_name, step_function in steps:
                pbar.set_description(f"Executing: {step_name}")
                step_function()
                pbar.update(1)

    def loaddata(self):
        """Read the CSV and split its rows by the value of the ``split`` column."""
        data = pd.read_csv(self.datafile)
        self.train_data = data[data['split'] == 'train']
        self.val_data = data[data['split'] == 'val']
        self.test_data = data[data['split'] == 'test']

    def prepare_data(self, data, target=None):
        """Tokenize one data split into a ``TensorDataset``.

        Args:
            data: DataFrame containing the molecule column(s) and the target.
            target: Target column name. Defaults to ``self.target``.

        Returns:
            ``TensorDataset(input_ids, attention_mask, y)``, where ``y`` has
            shape ``(n, 1)`` and dtype float32.

        Raises:
            ValueError: If the notation is unsupported or the required
                molecule column is missing.
        """
        target_column = self.target if target is None else target

        if self.notation == 'smiles':
            if 'smiles' not in data.columns:
                raise ValueError("notation='smiles' requires a 'smiles' column in the data.")
            sequences = data['smiles'].tolist()
        elif self.notation == 'selfies':
            # Prefer precomputed SELFIES; otherwise convert from SMILES.
            if 'selfies' in data.columns:
                sequences = data['selfies'].tolist()
            elif 'smiles' in data.columns:
                sequences = [selfies_encoder(s) for s in data['smiles'].tolist()]
            else:
                raise ValueError("notation='selfies' requires a 'selfies' or 'smiles' column in the data.")
        else:
            raise ValueError(f"Unsupported notation {self.notation!r}; expected 'smiles' or 'selfies'.")

        targets = data[target_column].values
        inputs = self.transformer.featurizer.tokenizer(sequences, truncation=True, padding=True,
                                                       return_tensors="pt")
        y = torch.tensor(targets, dtype=torch.float32).unsqueeze(1)
        return TensorDataset(inputs["input_ids"], inputs["attention_mask"], y)

    def loadLLM(self):
        """Load the pretrained transformer and tokenize all three splits."""
        # Note: The custom `PretrainedHFTransformer` must be able to pass `torch_dtype`
        # to the underlying Hugging Face `from_pretrained` call.
        self.transformer = PretrainedHFTransformer(kind=self.kind, notation=self.notation, dtype=self.compute_dtype,
                                                   preload=True, torch_dtype=self.compute_dtype)

        self.train_dataset = self.prepare_data(self.train_data)
        self.val_dataset = self.prepare_data(self.val_data)
        self.test_dataset = self.prepare_data(self.test_data)

    def load_to_gpu(self):
        """Pick CUDA when available, otherwise CPU, and move the models there.

        The ``gpu`` constructor flag is not consulted. Data is moved batch by
        batch in ``_create_encoder`` to save VRAM.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.model.to(self.device)
        self.transformer.featurizer.model.to(self.device)
        self.criterion.to(self.device)

    def lossfcn(self):
        """Set ``self.criterion`` from ``self.loss``: L1 for 'MAE', MSE for 'MSE', else SmoothL1."""
        if self.loss == 'MAE':
            self.criterion = nn.L1Loss()
        elif self.loss == 'MSE':
            self.criterion = nn.MSELoss()
        else:
            self.criterion = nn.SmoothL1Loss()  # Default to Huber Loss

    def loadNN(self):
        """Build the regression head, the loss, and the head's AdamW optimizer.

        Raises:
            ValueError: If ``self.nn`` is neither 'v1' nor 'v2'.
        """
        # Dynamically get input dimension from the transformer's config
        self.input_dim = self.transformer.featurizer.model.config.hidden_size
        self.output_dim = 1

        if self.nn == 'v1':
            self.model = FineTuningModel(input_dim=self.input_dim, output_dim=self.output_dim)
        elif self.nn == 'v2':
            self.model = FineTuningModelv2(input_dim=self.input_dim, hidden_dim=self.hidden_dim,
                                           output_dim=self.output_dim)
        else:
            raise ValueError(f"Unknown head variant nn={self.nn!r}; expected 'v1' or 'v2'.")

        self.lossfcn()

        # The paper uses AdamW for the main fine-tuning. We apply it to the head as well.
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.best_lr, weight_decay=self.best_weight_decay,
                                     betas=(0.99, 0.999), eps=1e-8)

    def setupLoader(self):
        """Create DataLoaders. Only the training loader may shuffle."""
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=self.best_batch_size, shuffle=self.shuffle)
        self.val_dataloader = DataLoader(self.val_dataset, batch_size=self.best_batch_size, shuffle=False)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.best_batch_size, shuffle=False)

    def find_all_linear_names(self, model):
        """Return the leaf names of ``nn.Linear`` layers to use as LoRA targets.

        If ``query``, ``key`` and ``value`` are all present, only those three
        are returned. Otherwise every linear leaf name except ``lm_head`` is
        returned in sorted (deterministic) order.
        """
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                # Keep only the last path component; PEFT matches on it.
                lora_module_names.add(name.split('.')[-1])

        # A common target setup for many transformer models
        if {'query', 'key', 'value'} <= lora_module_names:
            return ['query', 'key', 'value']

        # Adapting the output projection of an LM head is not wanted for
        # feature extraction.
        lora_module_names.discard('lm_head')
        return sorted(lora_module_names)

    def FineTuneWithLoRA(self):
        """Wrap the transformer with LoRA adapters and create ``self.lora_optimizer``."""
        target_modules = self.find_all_linear_names(self.transformer.featurizer.model)
        print(f"Applying LoRA to the following modules: {target_modules}")

        lora_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=self.best_r,
            lora_alpha=self.best_alpha,
            lora_dropout=self.best_dropout,
            target_modules=target_modules,
            bias="none",
        )

        self.transformer.featurizer.model = get_peft_model(self.transformer.featurizer.model, lora_config)
        self.transformer.featurizer.model.print_trainable_parameters()
        self.transformer.featurizer.model.train()

        # Only parameters left trainable by PEFT (the adapters) are optimised.
        lora_params = [p for p in self.transformer.featurizer.model.parameters() if p.requires_grad]
        self.lora_optimizer = optim.AdamW(lora_params, lr=self.best_lr, weight_decay=self.best_weight_decay,
                                          betas=(0.99, 0.999), eps=1e-8)

    # ... Other PEFT methods like QLoRA, BitFit, etc. remain unchanged ...
    # (Code for other PEFT methods is omitted for brevity but would be here in the full file)

    def _create_encoder(self, batch):
        """Split a DataLoader batch into model kwargs and targets on ``self.device``."""
        batch_input_id, batch_mask, batch_y = batch
        encoder = {
            "input_ids": batch_input_id.to(self.device),
            "attention_mask": batch_mask.to(self.device)
        }
        return encoder, batch_y.to(self.device)

    def _autocast(self):
        """Return the autocast context for forward passes.

        float32 is not an autocast target dtype, so autocast is simply
        disabled in that case instead of being asked to use it.
        """
        return torch.autocast(device_type=self.device.type, dtype=self.compute_dtype,
                              enabled=self.compute_dtype != torch.float32)

    def _neftune_forward_hook(self, module, inputs, output):
        """Add NEFTune noise to the output of the input-embedding layer.

        The noise is uniform in ``[-b, b]`` with ``b = alpha / sqrt(L * d)``,
        where ``L`` is ``output.size(1)`` and ``d`` is ``output.size(2)``.
        """
        seq_len, hidden = output.size(1), output.size(2)
        bound = self.neftune_noise_alpha / float(seq_len * hidden) ** 0.5
        return output + torch.zeros_like(output).uniform_(-bound, bound)

    def _get_embed(self, encoder):
        """Run the transformer and return attention-masked mean-pooled embeddings.

        In training mode, gradients are enabled and, if ``neftune_noise_alpha``
        is positive, NEFTune noise is added to the input embeddings through a
        temporary forward hook. The hook is always removed afterwards.
        """
        model = self.transformer.featurizer.model
        is_training = model.training
        use_neftune = is_training and self.neftune_noise_alpha is not None and self.neftune_noise_alpha > 0

        hook_handle = None
        if use_neftune:
            hook_handle = model.get_input_embeddings().register_forward_hook(self._neftune_forward_hook)
        try:
            with torch.set_grad_enabled(is_training):
                outputs = model(output_hidden_states=True, **encoder)
                hidden = getattr(outputs, "last_hidden_state", None)
                if hidden is None:
                    # Models without `last_hidden_state` (e.g. LM heads) still
                    # expose the per-layer hidden states we requested.
                    hidden = outputs.hidden_states[-1]

                # Mean over real tokens only, so padding does not dilute the
                # embedding. clamp avoids 0/0 for an all-padding row.
                mask = encoder["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                summed = (hidden * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1.0)
                final_embeddings = summed / counts
        finally:
            if hook_handle is not None:
                hook_handle.remove()

        return final_embeddings.to(self.device)

    def clip_grad(self):
        """Clip gradient norms to 1.0 for the head and, if PEFT is active, the transformer."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        if self.lora is not None:
            torch.nn.utils.clip_grad_norm_(self.transformer.featurizer.model.parameters(), max_norm=1.0)

    def trainwithLoRA(self, epoch):
        """Run one training epoch (stopping early at ``max_steps``).

        Returns:
            Mean training loss over the batches actually processed, or 0.0 if
            none were processed.
        """
        epoch_loss = 0.0
        num_batches = 0
        desc = f"LoRA Epoch {epoch + 1}/{self.best_epochs}"
        with tqdm(self.train_dataloader, desc=desc, leave=False) as pbar:
            for batch in pbar:
                # If max steps reached, stop this epoch
                if self.max_steps > 0 and self.global_step >= self.max_steps:
                    break

                encoder, batch_y = self._create_encoder(batch)

                self.lora_optimizer.zero_grad()
                self.optimizer.zero_grad()

                with self._autocast():
                    outputs = self._get_embed(encoder)
                    predictions = self.model(outputs)
                    loss = self.criterion(predictions, batch_y)

                loss.backward()
                if self.clip_gradient:
                    self.clip_grad()
                self.optimizer.step()
                self.lora_optimizer.step()

                batch_loss = loss.item()
                epoch_loss += batch_loss
                num_batches += 1
                self.global_step += 1
                pbar.set_postfix(loss=batch_loss)

        # Divide by batches actually run, not len(dataloader), so an epoch cut
        # short by max_steps reports the correct mean.
        avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
        print(f"Epoch [{epoch + 1}/{self.best_epochs}], Avg Train Loss: {avg_epoch_loss:.4f}")
        return avg_epoch_loss

    def evalwithLoRA(self, loader, desc="Validation"):
        """Return the per-sample mean loss over ``loader`` (0 for an empty loader).

        Leaves both the head and the transformer in eval mode.
        """
        total_loss = 0.0
        total_samples = 0
        self.model.eval()
        self.transformer.featurizer.model.eval()

        with torch.no_grad():
            for batch in tqdm(loader, desc=desc, leave=False):
                encoder, batch_y = self._create_encoder(batch)
                with self._autocast():
                    outputs = self._get_embed(encoder)
                    predictions = self.model(outputs)
                    loss = self.criterion(predictions, batch_y)
                # Weight each batch mean by its size so a short final batch
                # does not bias the dataset mean.
                n = batch_y.size(0)
                total_loss += loss.item() * n
                total_samples += n

        return total_loss / total_samples if total_samples > 0 else 0

    def _tunestep(self):
        """Apply the PEFT setup for ``self.lora`` (only 'lora' is handled here)."""
        if self.lora == 'lora':
            print("Fine-tuning with LoRA...")
            self.FineTuneWithLoRA()
        # ... other PEFT method initializations ...

    def train(self):
        """Run the fine-tuning loop and keep the best checkpoint by validation loss.

        Fills ``self.train_losses`` and ``self.val_losses`` with one entry per
        completed epoch.

        Raises:
            RuntimeError: If a LoRA-style method was selected but its setup step
                did not create ``self.lora_optimizer``.
        """
        self.train_losses = []
        self.val_losses = []
        best_val_loss = np.inf

        self._tunestep()

        if self.lora in ('lora', 'qlora') and not hasattr(self, 'lora_optimizer'):
            raise RuntimeError(f"No LoRA optimizer was configured for lora={self.lora!r}; "
                               "the PEFT setup step for this method is missing.")

        for epoch in range(self.best_epochs):
            print(f"\n--- Starting Epoch {epoch + 1}/{self.best_epochs} ---")

            self.model.train()
            if self.lora:
                self.transformer.featurizer.model.train()

            if self.lora in ['lora', 'qlora']:
                self.train_losses.append(self.trainwithLoRA(epoch))
            else:
                # Placeholder for non-LoRA training or other PEFT methods
                print("Training method for specified `lora` type not fully implemented.")
                break

            avg_val_loss = self.evalwithLoRA(self.val_dataloader, desc="Validation")
            print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")
            self.val_losses.append(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                print(f"New best validation loss: {best_val_loss:.4f}. Saving model...")
                torch.save(self.model.state_dict(), f'{self.name}.pth')
                if self.lora:
                    self.transformer.featurizer.model.save_pretrained(f'{self.name}_lora_adapters')

            if self.max_steps > 0 and self.global_step >= self.max_steps:
                print(f"Max steps ({self.max_steps}) reached. Halting training.")
                break

        print("\n--- Training Finished ---")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a model using the CatalystLLM strategy.")
    # --- Key arguments ---
    parser.add_argument("--data", type=str, required=True, help="Path to the dataset CSV file.")
    parser.add_argument("--name", type=str, default="CatalystLLM_run", help="Name for saving the model and results.")
    parser.add_argument("--lora", type=str, default='lora', choices=['lora', 'qlora', 'adaptive', 'prefix', 'none'],
                        help="PEFT method to use.")

    # --- Hyperparameters from paper ---
    parser.add_argument("--lr", type=float, default=5.0e-5, help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=8, help="Per-device batch size.")
    parser.add_argument("--max_steps", type=int, default=85255, help="Total number of training steps.")
    parser.add_argument("--r", type=int, default=8, help="LoRA rank.")
    parser.add_argument("--alpha", type=int, default=16, help="LoRA alpha.")
    parser.add_argument("--dropout", type=float, default=0.1, help="LoRA dropout.")
    parser.add_argument("--neftune_noise_alpha", type=float, default=5, help="NEFTune noise alpha.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")

    # --- Other general arguments ---
    parser.add_argument("--kind", type=str, default="ChemGPT-19M", help="Type of pretrained model to use.")
    parser.add_argument("--target", type=str, default="measured_log_sol", help="Target column in the dataset.")
    parser.add_argument("--epochs", type=int, default=2,
                        help="Number of epochs (training will stop early if max_steps is reached).")

    args = parser.parse_args()

    trainer = Trainer(
        data=args.data,
        name=args.name,
        lora=args.lora,
        lr=args.lr,
        batch_size=args.batch_size,
        max_steps=args.max_steps,
        epochs=args.epochs,
        r=args.r,
        alpha=args.alpha,
        dropout=args.dropout,
        neftune_noise_alpha=args.neftune_noise_alpha,
        weight_decay=args.weight_decay,
        kind=args.kind,
        target=args.target,
    )

    trainer.train()
